import numpy as np
import sys
import os
currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(currentdir, "..")))

import torch
import torch.nn as nn
from tool.util import set_seed, bool_flag
from datetime import datetime

from tool.args import get_general_args
from tool.util import init_wandb
from train.mlbase import MLBase
from evaluate.evaluator import Evaluator
import torch.nn.functional as F

from data.dl_getter import DATASETS, n_cls, sh, input_range

import pandas as pd
import argparse

from data.dl_getter import get_transform
import torchvision as tv
from torch.utils.data import DataLoader
import torchvision.transforms as tr


# Names of every input source that main() evaluates, in order:
#   'cifar10', 'svhn', 'cifar100', 'celeba' -> real datasets (see get_ood_loader)
#   'interp'                                -> averages of consecutive training batches
#   'N', 'U', 'OODomain', 'Constant'        -> synthetic inputs (see get_non_natural_latent)
# The entry equal to args.dataset is treated as the in-distribution set in main().
total_ds = ['cifar10', 'svhn', 'cifar100', 'interp', 'celeba',
            'N', 'U', 'OODomain', 'Constant']


def get_ood_loader(n_ds):
    """Build a non-shuffled DataLoader over the training split of a dataset.

    The data is read from ``~/data`` and is never downloaded, so it must
    already be present on disk.

    Args:
        n_ds: dataset name; one of 'cifar10', 'svhn', 'cifar100', 'celeba'.

    Returns:
        A ``DataLoader`` with ``batch_size=100`` and ``shuffle=False``.

    Raises:
        ValueError: if ``n_ds`` is not one of the supported names.
    """
    supported = ('cifar10', 'svhn', 'cifar100', 'celeba')
    if n_ds not in supported:
        # Fail early with a readable message instead of an unbound-variable
        # error further down.
        raise ValueError(
            f'Unsupported dataset {n_ds!r}; expected one of {supported}')

    if n_ds == 'cifar10':
        ds = tv.datasets.CIFAR10(
            root='~/data', train=True, transform=get_transform('cifar10'), download=False)
    elif n_ds == 'svhn':
        ds = tv.datasets.SVHN(
            root='~/data', split='train', transform=get_transform('svhn'), download=False)
    elif n_ds == 'cifar100':
        ds = tv.datasets.CIFAR100(
            root='~/data', train=True, transform=get_transform('cifar100'), download=False)
    else:  # 'celeba'
        # CelebA images are resized to 32 before the project transform is applied.
        ds = tv.datasets.CelebA(
            root='~/data', split='train',
            transform=tr.Compose([tr.Resize(32), get_transform('celeba')]),
            download=False)
    return DataLoader(ds, batch_size=100, shuffle=False)


@torch.no_grad()
def cal_dist(latent, m, batch_size=1000):
    """Return Euclidean distances from every row of ``latent`` to every row of ``m``.

    ``latent`` is processed ``batch_size`` rows at a time so that the
    broadcasted difference tensor stays bounded in size.

    Args:
        latent: tensor of vectors, one per row (expected layout: (N, D)).
        m: tensor of reference vectors, one per row (expected layout: (C, D)).
        batch_size: number of ``latent`` rows handled per chunk.

    Returns:
        Tensor whose entry ``[i, j]`` is ``||latent[i] - m[j]||``; for the
        expected layouts its shape is (N, C). An empty ``latent`` yields an
        empty result instead of raising.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    refs = m.unsqueeze(dim=0)
    total = latent.size(0)
    if total == 0:
        # torch.cat cannot concatenate an empty list; compute directly so the
        # result still carries the right trailing shape.
        return (latent.unsqueeze(dim=1) - refs).norm(dim=-1)

    chunks = []
    for start in range(0, total, batch_size):
        diff = latent[start:start + batch_size].unsqueeze(dim=1) - refs
        chunks.append(diff.norm(dim=-1))
    return torch.cat(chunks, dim=0)


@torch.no_grad()
def cal_logits(model, latent):
    """Run ``model.head`` over ``latent`` in chunks of 100 rows.

    Args:
        model: object exposing a callable ``head``.
        latent: tensor whose first dimension indexes samples.

    Returns:
        The per-chunk outputs of ``model.head`` concatenated along dim 0.
    """
    chunk = 100
    n_rows = latent.size(0)
    # Slicing past the end is clamped, so the last (possibly shorter) chunk
    # is covered by the same expression as the full ones.
    outputs = [
        model.head(latent[lo:lo + chunk])
        for lo in range(0, n_rows, chunk)
    ]
    return torch.cat(outputs, dim=0)


@torch.no_grad()
def get_latent(model, loader, interp=False, sigma=0.):
    """Encode every batch of ``loader`` with ``model.enc``.

    Args:
        model: module exposing ``enc``. Inputs are moved to the device of the
            model's first parameter (CUDA if the model has no parameters).
        loader: iterable yielding ``(x, y)`` batches.
        interp: if True, encode the average of each batch with the previous
            batch instead of the batch itself. The first batch only serves as
            a partner, so the loader must yield at least two batches.
        sigma: scale of the Gaussian noise added to the averaged inputs
            (only used when ``interp`` is True).

    Returns:
        ``(z, lbls)``: concatenated latents and labels. In interp mode the
        labels are those of the current (later) batch and are left on their
        original device; otherwise they are moved to the model's device.
    """
    # Follow the model's device rather than assuming the default CUDA device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    z, lbls = [], []
    if interp:
        last_batch = None
        for x, y in loader:
            x = x.to(device)
            if last_batch is not None:
                # The final batch of a loader is often smaller than the one
                # before it; only mix the rows both batches have.
                n = min(x.size(0), last_batch.size(0))
                cur = x[:n]
                x_mix = (cur + last_batch[:n]) / 2 + sigma * torch.randn_like(cur)
                z.append(model.enc(x_mix))
                lbls.append(y[:n])
            last_batch = x
    else:
        from tqdm import tqdm
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)
            z.append(model.enc(x))
            lbls.append(y)
    z = torch.cat(z)
    lbls = torch.cat(lbls)
    return z, lbls


@torch.no_grad()
def get_non_natural_latent(model, ds, loader):
    """Encode synthetic (non-natural) inputs, one synthetic batch per loader batch.

    The loader's images are only used for their batch size / shape; their
    content is ignored. The global torch RNG is reseeded to 42 on every call,
    so results are reproducible (and the caller's RNG state is overwritten).

    Args:
        model: module exposing ``enc``. Inputs are moved to the device of the
            model's first parameter (CUDA if the model has no parameters).
        ds: kind of synthetic input:
            'N'        - standard normal noise shaped like each batch,
            'U'        - uniform noise in [-1, 1) shaped like each batch,
            'OODomain' - uniform noise in [-10, 10) shaped like each batch,
            'Constant' - 3x32x32 images, each filled with one random colour.
        loader: iterable yielding ``(x, _)`` batches.

    Returns:
        Latents of all synthetic batches concatenated along dim 0.

    Raises:
        ValueError: if ``ds`` is not one of the supported kinds.
    """
    supported = ('N', 'U', 'OODomain', 'Constant')
    if ds not in supported:
        # Otherwise nothing is appended and torch.cat fails on an empty list.
        raise ValueError(
            f'Unsupported synthetic dataset {ds!r}; expected one of {supported}')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    torch.manual_seed(42)
    z = []
    for x, _ in loader:
        if ds == 'N':
            fake = torch.randn_like(x)
        elif ds == 'U':
            fake = torch.empty_like(x).uniform_(-1, 1)
        elif ds == 'OODomain':
            fake = torch.empty_like(x).uniform_(-10, 10)
        else:  # 'Constant'
            num_images = x.shape[0]
            # One random RGB colour per image, broadcast over all 32x32 pixels
            # in a channels-last buffer, then viewed as (N, 3, 32, 32).
            pixels = torch.rand((num_images, 3), dtype=torch.float32)
            images = torch.ones((num_images, 32, 32, 3), dtype=torch.float32)
            images = images * pixels[:, None, None, :]
            fake = images.permute(0, 3, 1, 2)
        z.append(model.enc(fake.to(device)))
    z = torch.cat(z)
    return z


@torch.no_grad()
def check_acc(model, vl_dl):
    """Print and return the top-1 accuracy of ``model`` over ``vl_dl``.

    The model is switched to eval mode and left in that mode. Accuracy is
    computed over the samples the loader actually yields, so it stays correct
    when the loader drops or subsamples part of its dataset.

    Args:
        model: callable module mapping a batch ``x`` to per-class scores.
            Inputs are moved to the device of its first parameter (CUDA if
            the model has no parameters).
        vl_dl: iterable yielding ``(x, y)`` batches.

    Returns:
        Fraction of correctly classified samples, or NaN if no samples were seen.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    model.eval()
    correct = 0
    total = 0
    for x, y in vl_dl:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        correct += (pred.argmax(1) == y).sum().item()
        total += y.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {accuracy}')
    return accuracy


# Example invocation:
# python evaluate/ana_id_ood.py --wandb_entity eavnjeong --bsz 100 --bsz_vl 100 --head lin --arch resnet34 --exp_load cifar100_lin --method evaluate --dataset cifar100
@torch.no_grad()
def main(model, tr_dl, vl_dl, args):
    """Bin latent statistics by top-1 softmax confidence and write one CSV per input source.

    Steps:
      1. Print the accuracy on ``vl_dl`` and encode ``tr_dl`` into latents/labels.
      2. Pick the class-weight matrix ``W`` and an ``origin`` that depend on the
         head type, then compute per-class means ``m`` of the origin-shifted
         training latents.
      3. For every name in ``total_ds``: obtain latents, shift them by
         ``origin``, compute logits with ``model.head``, and for each of the 20
         confidence bins ``(lb, lb + 0.05]`` record the sample count plus
         mean/std of the top-1/top-2 logits, the latent norm, the cosine to the
         predicted classes' rows of ``W`` and the distance to the predicted
         classes' means. Empty bins are recorded as zeros.
      4. Write the per-bin table with ``save_csv`` to
         ``<args.output_path>/<args.dataset>(in)_<name>(out).csv``.

    Args:
        model: network exposing ``enc`` and ``head``.
        tr_dl: training loader; source of the in-distribution and 'interp' latents.
        vl_dl: validation loader; used for the accuracy check and to size the
            synthetic batches.
        args: namespace providing ``n_cls``, ``dataset`` and ``output_path``.
    """
    check_acc(model, vl_dl)
    z_id, lbls = get_latent(model, tr_dl)

    # "ours" head (identified by a `num` attribute): class vectors are
    # head.num.ms and latents are not shifted.
    if hasattr(model.head, 'num'):
        W = model.head.num.ms
        origin = 0
    # "wr" head: plain linear layer `fc`.
    else:
        W = model.head.fc.weight.data
        b = model.head.fc.bias
        # origin = -pinv(W^T)^T b; presumably the point the linear layer maps
        # to zero logits (holds when W pinv(W) = I) -- TODO confirm intent.
        W_plus = torch.pinverse(W.T)
        origin = -torch.matmul(W_plus.T, b)

    softmax = nn.Softmax(1)
    # Lower edges of the 20 confidence bins: 0.00, 0.05, ..., 0.95.
    lower_bound = np.linspace(0, 1, 21)[:-1]
    # m: per-class mean of the origin-shifted training latents, stacked by class index.
    tmp_latent = z_id - origin
    m = [tmp_latent[lbls == lbl].mean(0) for lbl in range(args.n_cls)]
    m = torch.stack(m)

    for ood_ds in total_ds:
        # Per-bin accumulators; each ends up with one entry per value in lower_bound.
        top1_logits_mean, top2_logits_mean = [], []
        top1_logits_std, top2_logits_std = [], []
        counts = []
        z_norm_means, z_norm_stds = [], []
        top1_cos_theta_mean, top1_cos_theta_std = [], []
        top2_cos_theta_mean, top2_cos_theta_std = [], []
        top1_dist_mean, top2_dist_mean = [], []
        top1_dist_std, top2_dist_std = [], []
        
        # Select the latents for this source; the in-distribution dataset
        # reuses the training latents computed above.
        if ood_ds == args.dataset:
            latent = z_id
        elif ood_ds == 'interp':
            latent, _ = get_latent(model, tr_dl, True)
        elif ood_ds in ['cifar10', 'svhn', 'cifar100', 'celeba']:
            dl = get_ood_loader(ood_ds)
            latent, _ = get_latent(model, dl)
        elif ood_ds in ['N', 'U', 'OODomain', 'Constant']:
            latent = get_non_natural_latent(model, ood_ds, vl_dl)

        # The head is applied to the origin-shifted latents.
        latent = latent - origin
        logits = cal_logits(model, latent)
        # Top-2 softmax probabilities with their class indices, and the top-2 raw logits.
        prob, pred = torch.topk(softmax(logits), 2)
        logits_k, _ = torch.topk(logits, 2)

        print(prob[:, 0].mean().item(), prob[:, 1].mean().item(), \
            prob[:, 0].std().item(), prob[:, 1].std().item())
        # Top-1 probability and top-1/top-2 logits, moved to CPU for binning.
        top1_prob = prob[:, 0].detach().cpu()
        top1_logits = logits_k[:, 0].detach().cpu()
        top2_logits = logits_k[:, 1].detach().cpu()
        # L2 norm of every (shifted) latent.
        z_norm = latent.norm(dim=1)
        # Cosine between each latent and each row of W.
        # zn: row-normalized latents (original note: (50000, 512)).
        zn = F.normalize(latent, dim=1)
        # wn: row-normalized class vectors (original note: (100, 512)).
        wn = F.normalize(W, dim=1)
        cos_thetas = torch.matmul(zn, wn.transpose(0, 1))
        # Euclidean distance from each latent to each class mean in m.
        dist = cal_dist(latent, m)

        for lb in lower_bound:
            # Samples whose top-1 probability falls in (lb, lb + 0.05].
            select_index = (lb < top1_prob) & (top1_prob <= lb + 0.05)
            if select_index.sum().item() == 0:
                # Empty bin: record zeros for every statistic so rows stay aligned.
                top1_logits_mean.append(0); top2_logits_mean.append(0)
                top1_logits_std.append(0); top2_logits_std.append(0)
                z_norm_means.append(0); z_norm_stds.append(0) 
                counts.append(0)
                top1_cos_theta_mean.append(0); top2_cos_theta_mean.append(0)
                top1_cos_theta_std.append(0); top2_cos_theta_std.append(0)
                top1_dist_mean.append(0); top2_dist_mean.append(0)
                top1_dist_std.append(0); top2_dist_std.append(0)
                continue
            # Predicted (top-1) and runner-up (top-2) class indices of the selected samples.
            top1_pred= pred[:, 0][select_index]
            top2_pred = pred[:, 1][select_index]
            
            # NOTE(review): with a single sample in the bin the .std() calls
            # below presumably yield NaN -- confirm this is acceptable downstream.
            top1_logits_mean.append(top1_logits[select_index].mean().item())
            top2_logits_mean.append(top2_logits[select_index].mean().item())
            top1_logits_std.append(top1_logits[select_index].std().item())
            top2_logits_std.append(top2_logits[select_index].std().item())

            z_norm_means.append(z_norm[select_index].mean().item())
            z_norm_stds.append(z_norm[select_index].std().item())
            counts.append(select_index.sum().item())

            # Cosine to the top-1 / top-2 predicted class vector of each selected sample.
            top1_cos = torch.gather(cos_thetas[select_index], 1, top1_pred.unsqueeze(dim=1))
            top2_cos = torch.gather(cos_thetas[select_index], 1, top2_pred.unsqueeze(dim=1))
            top1_cos_theta_mean.append(top1_cos.mean().item())
            top2_cos_theta_mean.append(top2_cos.mean().item())
            top1_cos_theta_std.append(top1_cos.std().item())
            top2_cos_theta_std.append(top2_cos.std().item())

            # Distance to the class mean of the top-1 / top-2 predicted class.
            top1_dist = torch.gather(dist[select_index], 1, top1_pred.unsqueeze(dim=1))
            top2_dist = torch.gather(dist[select_index], 1, top2_pred.unsqueeze(dim=1))
            top1_dist_mean.append(top1_dist.mean().item())
            top2_dist_mean.append(top2_dist.mean().item())
            top1_dist_std.append(top1_dist.std().item())
            top2_dist_std.append(top2_dist.std().item())
        # One CSV per input source, one row per confidence bin.
        save_csv(counts, top1_logits_mean, top2_logits_mean,
                    top1_logits_std, top2_logits_std,
                    z_norm_means, z_norm_stds,
                    top1_cos_theta_mean, top2_cos_theta_mean,
                    top1_cos_theta_std, top2_cos_theta_std,
                    top1_dist_mean, top2_dist_mean,
                    top1_dist_std, top2_dist_std, 
                    os.path.join(args.output_path, f'{args.dataset}(in)_{ood_ds}(out).csv'))


def save_csv(counts, top1_logits_mean, top2_logits_mean,
                top1_logits_std, top2_logits_std,
                z_norm_means, z_norm_stds,
                top1_cos_theta_mean, top2_cos_theta_mean,
                top1_cos_theta_std, top2_cos_theta_std,
                top1_dist_mean, top2_dist_mean,
                top1_dist_std, top2_dist_std, 
                save_path):
    """Write the per-bin statistics to ``save_path`` as a CSV table.

    Every statistic argument is a sequence with one value per bin; each
    becomes one column and each bin becomes one row (with the default integer
    index written as the first, unnamed CSV column).

    Column order: 'counts', 'top1-logits', 'top2-logits' (these two hold the
    means), 'top1-logits-std', 'top2-logits-std', 'z-norm-mean', 'z-norm-std',
    'top1-cos-theta-mean', 'top2-cos-theta-mean', 'top1-cos-theta-std',
    'top2-cos-theta-std', 'top1-dist-mean', 'top2-dist-mean', 'top1-dist-std',
    'top2-dist-std'.

    Args:
        counts: number of samples per bin.
        top1_logits_mean, top2_logits_mean: mean top-1 / top-2 logit per bin.
        top1_logits_std, top2_logits_std: std of the top-1 / top-2 logit per bin.
        z_norm_means, z_norm_stds: mean / std of the latent norm per bin.
        top1_cos_theta_mean, top2_cos_theta_mean: mean cosine per bin.
        top1_cos_theta_std, top2_cos_theta_std: std of the cosine per bin.
        top1_dist_mean, top2_dist_mean: mean distance per bin.
        top1_dist_std, top2_dist_std: std of the distance per bin.
        save_path: destination CSV file path.
    """
    named_columns = [
        ('counts', counts),
        ('top1-logits', top1_logits_mean),
        ('top2-logits', top2_logits_mean),
        ('top1-logits-std', top1_logits_std),
        ('top2-logits-std', top2_logits_std),
        ('z-norm-mean', z_norm_means),
        ('z-norm-std', z_norm_stds),
        ('top1-cos-theta-mean', top1_cos_theta_mean),
        ('top2-cos-theta-mean', top2_cos_theta_mean),
        ('top1-cos-theta-std', top1_cos_theta_std),
        ('top2-cos-theta-std', top2_cos_theta_std),
        ('top1-dist-mean', top1_dist_mean),
        ('top2-dist-mean', top2_dist_mean),
        ('top1-dist-std', top1_dist_std),
        ('top2-dist-std', top2_dist_std),
    ]
    series = [pd.Series(values, name=name) for name, values in named_columns]
    # `axis` must be passed by keyword: pandas 2.x rejects it positionally.
    df = pd.concat(series, axis=1)
    df.to_csv(save_path)


if __name__ == '__main__':
    # Script entry point: parse the arguments, initialise W&B, build the
    # evaluator and run the analysis on its model and data loaders.
    args = get_general_args()
    init_wandb(args)
    # Named `evaluator` rather than `eval` so the builtin is not shadowed.
    evaluator = Evaluator(MLBase(args))

    model = evaluator.model
    tr_dl = evaluator.tr_dl
    vl_dl = evaluator.vl_dl
    print(model.head)
    main(model, tr_dl, vl_dl, args)

